# -*- coding: utf-8 -*-
"""
Created on Tue Nov  5 06:03:27 2019

@author: Salem

computes the minimum of the energy after applying a force
"""


import numpy as np
from numpy import linalg as la
import numpy.random as npr
import scipy as sp
import scipy.linalg as sla
import scipy.optimize as op

import matplotlib.pyplot as plt


#--------------------------------------------------------------------------
from pathlib import WindowsPath
from pathlib import Path
import os

# Directory the interpreter was started from; the data folder is resolved relative to it.
cwd = Path(os.getcwd())

# Default folder used by the functions below for reading and writing the mesh CSV files.
data_dir = cwd / "Elastic Shell Mesh Data"

#--------------------------------------------------------------------------

#======================================================================================================================================
# updates the mesh for different values of sharpening rate
#======================================================================================================================================
def update_all_meshes(start_index = 0, end_index=28):
    '''
    Run update_mesh on every mesh index from start_index to end_index.

    input:
        start_index: first mesh index to process.
        end_index: last mesh index to process (inclusive).
    '''
    # end_index is inclusive, hence the + 1
    mesh_indices = range(start_index, end_index + 1)

    for mesh_index in mesh_indices:
        print("working on mesh number: ", mesh_index)
        update_mesh(mesh_index)

    return
#========================================================================================================================================
#========================================================================================================================================


#======================================================================================================================================
# applies the tip force to one mesh and saves the new equilibrium positions, displacements and strains
#======================================================================================================================================
def update_mesh(sharpening_rate, load_dir=data_dir, save_dir = data_dir, fixed_distace=0.01):
    '''
    Load one beak mesh, apply the tip force, and save the relaxed configuration.

    input:
        sharpening_rate: index of the mesh; it selects the CSV files that are read and
            is used in the names of the files that are written.
        load_dir: folder the mesh files are read from.
        save_dir: folder the result files are written to.
        fixed_distace: vertices whose x coordinate lies within this distance of the
            largest x coordinate are held fixed during the minimization.

    output:
        None. Three CSV files are written to save_dir:
            new_vertices_<sharpening_rate>.csv      vertex positions after the displacement
            beak_displacements_<sharpening_rate>.csv displacement of every vertex
            beak_strains_<sharpening_rate>.csv      |change in edge length| / original edge length
    '''

    # forward load_dir so that a non-default folder is actually honoured
    vertices, edge_array, _ = generate_beak_mesh(sharpening_rate, load_dir=load_dir)

    # accept plain strings as well as Path objects for the output folder
    save_dir = Path(save_dir)

    # kept as module-level globals because other functions in this file may read them
    global num_of_verts, num_of_edges, dim
    num_of_verts = vertices.shape[0]; dim = 3
    num_of_edges = edge_array.shape[0]

    # rest lengths of the edges
    edge_lengths = np.sqrt(np.sum((vertices[edge_array[:,0]] - vertices[edge_array[:,1]])**2, axis=1))

    #connectivity dependent matrices that help in the calculation
    edge_mat1, edge_mat2 = make_edge_matrix(edge_array, num_of_verts)

    # one constant per edge, inversely proportional to the rest length
    spring_consts = 10*(np.ones(edge_array.shape[0]))/edge_lengths

    #these vertices will be fixed to the wall
    x_coords = vertices[:,0]
    fixed_vert_indices = np.flatnonzero(x_coords > x_coords.max() - fixed_distace)

    #the index of each fixed displacement component after flattening for optimization
    # (vertex-major order: dim*index + d), always an integer array even when empty
    fixed_disp_indices = (dim*fixed_vert_indices[:, np.newaxis] + np.arange(dim)).ravel().astype(int)

    #vertices is not flattened here, but inside makeRigidityMat
    dynMat = makeDynamicalMat(vertices, edge_mat1 , edge_mat2, spring_consts)

    force = get_applied_force(vertices)

    # force keeps the (num_of_verts, dim) shape here; get_displacement flattens it
    displacements = get_displacement(dynMat, force, fixed_disp_indices)
    new_verts = vertices + displacements

    new_edge_lengths = np.sqrt(np.sum((new_verts[edge_array[:,0]] - new_verts[edge_array[:,1]])**2, axis=1))
    strains = np.abs((new_edge_lengths - edge_lengths)/edge_lengths)

    np.savetxt(save_dir / ("new_vertices_" + str(sharpening_rate) + ".csv"), new_verts, delimiter=',')
    np.savetxt(save_dir / ("beak_displacements_" + str(sharpening_rate) + ".csv"), displacements, delimiter=',')
    np.savetxt(save_dir / ("beak_strains_" + str(sharpening_rate) + ".csv"), strains, delimiter=',')

    return
#========================================================================================================================================
#========================================================================================================================================

    


 
#=========================================================================================================================
# generate the mesh that represents the beak shape
#=========================================================================================================================
def generate_beak_mesh(sharpening_rate, load_dir=data_dir):
    '''
    Read the vertices and edges of one beak mesh from CSV files in load_dir.

    input:
        sharpening_rate: index of the mesh, used to build the file names
            beak_vertices_<sharpening_rate>.csv and beak_edges_<sharpening_rate>.csv.
        load_dir: folder containing the two files.

    output:
        (vertices, edge_array, neighbors) where neighbors is always an empty list
        (the neighbor list is currently not loaded).
    '''
    file_suffix = str(sharpening_rate) + ".csv"
    vertex_path = load_dir / ("beak_vertices_" + file_suffix)
    edge_path = load_dir / ("beak_edges_" + file_suffix)

    vertices = np.loadtxt(vertex_path, delimiter=',')

    # the stored indices are shifted down by one (presumably 1-based in the file -- TODO confirm)
    stored_edges = np.loadtxt(edge_path, dtype=int, delimiter=',')
    edge_array = stored_edges - 1

    neighbors = []
    return vertices, edge_array, neighbors
#=========================================================================================================================
#=========================================================================================================================
 




#================================================================================================================================
# gets the force applied at the tip from its magnitude and the range near the tip where it will be applied
#===============================================================================================================================
def get_applied_force(vertices, force_magnitude=0.05, force_range = 0.05, force_direction = np.array([1, 0.0, 1])):
    '''
    get_applied_force(vertices, force_magnitude, force_range, force_direction):
        Every vertex whose x coordinate (column 0) is smaller than force_range receives
        the force  force_magnitude*force_direction. The force is zero for the rest of
        the vertices.

        force_direction is used as given (it is not normalized), so the norm of the
        applied force is force_magnitude*|force_direction|.

        The returned array has the same shape as vertices and a floating point dtype,
        even when vertices holds integers (an integer array would silently truncate
        the force to zero).
    '''

    vertices = np.asarray(vertices)

    # keep a floating vertex dtype, otherwise fall back to float
    if np.issubdtype(vertices.dtype, np.floating):
        force_dtype = vertices.dtype
    else:
        force_dtype = float
    force = np.zeros_like(vertices, dtype=force_dtype)

    #find vertices that are within range of the tip
    near_tip = vertices[:,0] < force_range

    #apply force
    force[near_tip] = force_magnitude*np.asarray(force_direction, dtype=float)

    return force
#===================================================================================================================================
#===================================================================================================================================
 
    
#==================================================================================================================================
    
# Finds the displacement that minimizes the energy under the applied force
#==================================================================================================================================
    
def get_displacement(dynMat, force, fixed_disp_indices):
    """
    find the displacement by minimizing the energy with BFGS, starting from zero displacement

    dynMat: is the dynamical matrix

    force: applied force. When it is 2D, (number of vertices, dimension), the result has
            the same shape. When it is already flattened, the module-level globals
            num_of_verts and dim are used to shape the result.

    fixed_disp_indices: indices (into the flattened displacement) of the components
            that are held at zero

    returns the displacement as an array of shape (number of vertices, dimension)
    """
    import warnings

    force = np.asarray(force, dtype=float)
    if force.ndim == 2:
        n_verts, n_dim = force.shape
    else:
        n_verts, n_dim = num_of_verts, dim

    # op.minimize requires a one dimensional starting point
    init_displacement = np.zeros(n_verts*n_dim)

    res = op.minimize(energy, init_displacement, method='BFGS', args=(force.flatten(), dynMat, fixed_disp_indices), jac=energy_grad)

    # do not fail, but do not hide a minimization that did not report convergence either
    if not res.success:
        warnings.warn("BFGS did not report convergence: " + str(res.message), RuntimeWarning)

    # the optimizer works on its own copy of the displacement, so enforce the
    # constraint explicitly on the solution that is returned
    displacement = np.array(res.x, dtype=float)
    if np.size(fixed_disp_indices):
        displacement[fixed_disp_indices] = 0

    return displacement.reshape(n_verts, n_dim)
#=======================================================================================================================================


#=======================================================================================================================================
# Calculates the energy after the lengths have changed
#======================================================================================================================================
def energy(u, force, dynMat,  fixed_indices):
    """
    This calculates the energy as a function of the displacement from the previous
    equilibrium:

        E = 0.5 * u.D.u - force.u

    evaluated with the fixed components of u set to zero.

    u: The (flattened) displacement of the vertices. It is not modified.

    force: the (flattened) applied force

    dynMat: dynamical matrix D

    fixed_indices: indices of the components of u that are held at zero.
            May be empty.
    """

    # work on a copy so that the caller's array is left untouched
    constrained = np.array(u, dtype=float)
    if np.size(fixed_indices):
        constrained[fixed_indices] = 0

    return 0.5*np.dot(np.dot(constrained, dynMat), constrained) -  np.dot(force, constrained)
#=========================================================================================================================================
    

#========================================================================================================================================
# Calculates the gradient of the energy with respect to the displacement
#========================================================================================================================================
def energy_grad(u, force, dynMat,  fixed_indices):
    """
    This calculates the energy gradient as a function of the displacement from the previous
    equilibrium:

        grad = u.D - force

    evaluated with the fixed components of u set to zero.

    The energy does not depend on the fixed components, so the corresponding
    entries of the gradient are zero. This keeps the gradient consistent with
    the energy and keeps a gradient based minimizer from moving the fixed components.

    u: The (flattened) displacement of the vertices. It is not modified.

    force: the (flattened) applied force

    dynMat: dynamical matrix D

    fixed_indices: indices of the components of u that are held at zero.
            May be empty.
    """

    has_fixed = np.size(fixed_indices) > 0

    # work on a copy so that the caller's array is left untouched
    constrained = np.array(u, dtype=float)
    if has_fixed:
        constrained[fixed_indices] = 0

    grad = np.dot(constrained, dynMat) -  force

    if has_fixed:
        grad[fixed_indices] = 0

    return grad
#=======================================================================================================================================   
    
#=======================================================================================================================================
# Calculates the Hessian of the energy with respect to the displacement
#=======================================================================================================================================
def energy_hess(u, force, dynMat,  fixed_indices):
    """
    This calculates the energy Hessian for the energy

        E = 0.5 * u.D.u - force.u

    in which the fixed components of u are held at zero.

    The energy does not depend on the fixed components, so the rows and columns
    of the Hessian that belong to them are zero; the remaining entries are those
    of the dynamical matrix. A new array is returned, dynMat and u are not modified.

    u: The (flattened) displacement of the vertices (unused, the energy is quadratic).

    force: the (flattened) applied force (unused)

    dynMat: dynamical matrix D

    fixed_indices: indices of the components of u that are held at zero.
            May be empty.
    """

    hessian = np.array(dynMat, dtype=float)

    if np.size(fixed_indices):
        hessian[fixed_indices, :] = 0
        hessian[:, fixed_indices] = 0

    return hessian
#========================================================================================================================================= 

   
#====================================================================================================================================
# returns the Dynamical Matrix as an array
#====================================================================================================================================
def makeDynamicalMat(verts, edgeMat1 , edgeMat2, springK):
    """
    makeDynamicalMat(verts, edgeMat1, edgeMat2, springK):
        Builds the rigidity matrix R from the vertex positions and the two edge
        matrices (see makeRigidityMat and make_edge_matrix) and returns the
        dynamical matrix

            D = R.T  diag(springK**2)  R

        which is square with one row/column per flattened displacement component.

    verts: vertex positions, flattened inside makeRigidityMat

    edgeMat1, edgeMat2: the edge matrices returned by make_edge_matrix

    springK: one value per edge (1D array); its square weights the edge
    """

    RigidityMat = makeRigidityMat(verts, edgeMat1 , edgeMat2)

    # scaling the columns of R.T by the weights is the same as multiplying by
    # diag(weights), without building a dense (edges x edges) matrix
    weights = np.asarray(springK)**2

    #return the dynamical matrix
    return np.dot(RigidityMat.transpose()*weights, RigidityMat)
    
#=====================================================================================================================================

#===================================================================================================================================
# returns the Rigidity Matrix as an array
#===================================================================================================================================
def makeRigidityMat(verts, edgeMat1 , edgeMat2):
    """
    makeRigidityMat(verts, edgeMat1, edgeMat2):
        Builds the (number of edges) x (dim * number of vertices) rigidity matrix from
        the vertex positions and the two edge matrices returned by make_edge_matrix.

        Row e holds the components of the difference vector between the two end
        points of edge e (first vertex minus second vertex), with a plus sign in the
        columns of the first vertex and a minus sign in the columns of the second
        vertex. The difference vectors are not normalized.

    verts: array of vertex positions, it is flattened here

    Example (one edge between two vertices in 2D):
            edgeMat1 = [[1, 0, -1,  0],
                        [0, 1,  0, -1]]
            edgeMat2 = [[1, 1]]
            verts    = [[0, 0], [1, 1]]
        Out: array([[-1., -1.,  1.,  1.]])
    """

    # components of every edge vector, one entry per (edge, direction) pair
    edge_components = np.dot(edgeMat1, verts.flatten())

    # scale each row of edgeMat1 by its edge component
    scaled_rows = np.multiply(edgeMat1, edge_components[:, np.newaxis])

    # add up the rows that belong to the same edge
    return np.dot(edgeMat2, scaled_rows)

#=====================================================================================================================================    

#===================================================================================================================================
# Computes the two auxiliary edge matrices needed in the computation of the energy and its gradient
#===================================================================================================================================  
def make_edge_matrix(edges, num_verts, dim=3):
    '''
    computes the two auxiliary edge matrices used to build the rigidity matrix

    edges: array with one row per edge holding the two vertex indices

    returns (edge_matrix1, edge_matrix2):
        edge_matrix1, shape (dim*num_edges, dim*num_verts): applied to the flattened
            vertex positions it gives, for every edge and direction, the coordinate
            of the first vertex minus the coordinate of the second vertex.
        edge_matrix2, shape (num_edges, dim*num_edges): sums the dim entries that
            belong to the same edge.
    '''

    num_edges = edges.shape[0]

    difference_mat = np.zeros((dim*num_edges, dim*num_verts))
    summation_mat = np.zeros((num_edges, dim*num_edges))

    directions = np.arange(dim)

    for edge_index, edge in enumerate(edges):
        # the dim consecutive rows that belong to this edge
        rows = dim*edge_index + directions

        # +1 on the first vertex, then -1 on the second vertex (same order as assignments
        # per direction, so a self-loop ends with -1)
        difference_mat[rows, dim*edge[0] + directions] = 1
        difference_mat[rows, dim*edge[1] + directions] = -1

        summation_mat[edge_index, rows] = 1

    return difference_mat, summation_mat
#===================================================================================================================================
#===================================================================================================================================  

#======================================================================================================================================
# returns the adjacency matrix as an array
#======================================================================================================================================
def makeAdjacencyMatrix(edgeArray, numOfVerts=-1):
    """
    makeAdjacencyMatrix(edgeArray, numOfVerts=-1):
        Takes in the edgeArray, which has rows of the form [vert1, vert2], and finds the
        (numOfVerts x numOfVerts) adjacency matrix together with a map from vertex
        pairs to edge indices.

        When numOfVerts < 1 the number of vertices is inferred from the edges: it is
        large enough to hold the largest vertex index, so vertices that do not appear
        in any edge do not cause an indexing error.

        returns (adjacencyMat, edgeMap):
            adjacencyMat[i, j] = 1 when there is an edge between i and j, 0 otherwise.
            edgeMap[i, j] = index of the edge between i and j. Pairs without an edge
                hold 0, which is also the index of the first edge, so use adjacencyMat
                to tell the two apart.

      Example: makeAdjacencyMatrix(np.array([[0, 1], [1, 2]]))[0]
       Out:  array([[0, 1, 0],
       [1, 0, 1],
       [0, 1, 0]])
    """

    edgeArray = np.asarray(edgeArray)

    if numOfVerts < 1:
        if edgeArray.size == 0:
            numOfVerts = 0
        else:
            distinct_count = len(set(edgeArray.flatten().tolist()))
            numOfVerts = max(distinct_count, int(edgeArray.max()) + 1)

    # np.int was removed from NumPy, the builtin int is the same dtype
    adjacencyMat = np.zeros((numOfVerts, numOfVerts), dtype=int)
    edgeMap = np.zeros((numOfVerts, numOfVerts), dtype=int)

    for eIndex, edge in enumerate(edgeArray):
        adjacencyMat[edge[1], edge[0]] = adjacencyMat[edge[0], edge[1]] = 1
        edgeMap[edge[1], edge[0]] = edgeMap[edge[0], edge[1]] = eIndex

    return adjacencyMat, edgeMap
#=====================================================================================================================================   
    

#=================================================================================
#find the dihedral vertices (the two vertices opposite to each edge) from the neighbor list
#=================================================================================
def DihedralVertices(Neibs, EList):
    '''
        Finds the dihedral vertices of the edges: for every edge the common neighbors
        of its two end points are computed, and when there are exactly two of them
        they are recorded (edges with a different number of common neighbors are skipped).

        Neibs: for each vertex index, the indices of its neighbors. Can be an array or list.
        EList: (numEdges, 2) array of vertex indices.

        return (n, 2) array containing, for each recorded edge, the indices of the two
        common neighbors. When no edge qualifies an empty (0, 2) integer array is
        returned, so the result can always be stacked with an edge array.
    '''

    FaceVerts = []

    #loop over all the edges
    for edge in EList:
        #find the two triangles intersecting at the edge by finding common neighbors
        common = np.intersect1d(Neibs[edge[0]], Neibs[edge[1]])

        if common.size == 2:
            FaceVerts.append(common)

    if not FaceVerts:
        return np.empty((0, 2), dtype=int)

    return np.array(FaceVerts)
#=================================================================================
    

#=================================================================================
#find the neighbor list from the edge list
#=================================================================================
def NeibsFromEdges(numVerts, EdgeList):
    '''
        Returns a Neighbor list and a map between the Neighbor list and the edge list.

        numVerts: number of vertices; end points outside 0 .. numVerts-1 are ignored.
        EdgeList: sequence of edges, each giving two vertex indices.

        returns (NeibList, MapNL), two python lists with one row per vertex
        (python lists because the rows can have different lengths):
            NeibList[v][k] is the k-th neighbor of vertex v, in edge-list order.
            MapNL[v][k] is the index in EdgeList of the edge joining v to that neighbor.
    '''

    #one independent row per vertex ( [[]]*numVerts would alias a single list )
    NeibList = [[] for _ in range(numVerts)]
    MapNL = [[] for _ in range(numVerts)]

    #a single pass over the edges fills the rows of both end points
    for Eindx, edge in enumerate(EdgeList):
        first, second = edge[0], edge[1]

        if 0 <= first < numVerts:
            NeibList[first].append(second)
            MapNL[first].append(Eindx)

        #a self-loop is recorded only once
        if second != first and 0 <= second < numVerts:
            NeibList[second].append(first)
            MapNL[second].append(Eindx)

    return (NeibList, MapNL)
#=================================================================================
    

#=================================================================================
#flatten a list 
#=================================================================================
def flatten(lis):
    """Return a flat list with the items of lis, expanding nested lists at any depth.

    Only items that are exactly of type list are expanded; tuples and other
    containers are kept as single items.
    """
    flat = []
    for element in lis:
        if type(element) is list:
            flat += flatten(element)
        else:
            flat += [element]
    return flat
#=================================================================================